# coding:utf-8
'''
author:wangyi
'''

import tensorflow as tf
import numpy as np
from sklearn.cluster import KMeans
import pickle

def source_net():
    '''
    构建一个可训练的网络,并保存,权重为w
    :return:
    '''

    x_datas = np.random.rand(100, 4)
    y_datas = 2*x_datas

    input_x = tf.placeholder(dtype=tf.float32,shape=[None,4])
    input_y = tf.placeholder(dtype=tf.float32,shape=[None,4])

    w = tf.Variable(initial_value=tf.truncated_normal(shape=[4,4],stddev=0.1),name='w')
    b = tf.Variable(tf.constant([0.0],shape=[4]),name='b')

    logits = tf.matmul(input_x,w)+b

    loss = tf.reduce_mean(tf.square(logits-input_y))

    op = tf.train.AdamOptimizer(1e-3).minimize(loss)

    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    for i in range(100):
        sess.run(op,feed_dict={input_x:x_datas,input_y:y_datas})
        print('loss',sess.run(loss,feed_dict={input_x:x_datas,input_y:y_datas}))
    saver = tf.train.Saver(tf.global_variables())
    saver.save(sess,save_path='./qt_output/best_model')
    sess.close()

def quant_process():
    '''
    获取网络权重,并根据量化流程获取索引矩阵,并保存
    :return:
    '''
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())

    saver = tf.train.import_meta_graph('./qt_output/best_model.meta')
    saver.restore(sess,save_path='./qt_output/best_model')

    graph = tf.get_default_graph()
    # 获取原网络权重
    w = graph.get_tensor_by_name('w:0')
    b = graph.get_tensor_by_name('b:0')
    w_src = w.eval(session=sess)
    b_src = b.eval(session=sess)
    # 将矩阵转化成向量
    w_src_r = np.reshape(w_src,[w_src.shape[0]*w_src.shape[1],1])
    # 聚类
    y_pred = KMeans(n_clusters=4, random_state=9).fit_predict(w_src_r)

    # 构建聚类中心和权重实值的映射表
    value2cls = {w_src_r[i][0]:y_pred[i] for i in range(len(w_src_r))}

    # 获取权重矩阵并保存
    for i in range(len(w_src)):
        for j in range(len(w_src[i])):
            w_src[i][j] = value2cls[w_src[i][j]]

    pickle.dump([w_src,b_src],open('w_b_src.pkl','wb'))
    sess.close()






def update_net():
    '''
    重新构建网络的前向结构,使用量化矩阵作为权重,此时不需要再构建loss和优化器,只构建前向网络结构就行
    :return:
    '''
    # 获取量化矩阵
    tf.reset_default_graph()
    w_src,b_src = pickle.load(open('w_b_src.pkl','rb'))
    input_x = tf.placeholder(dtype=tf.float32, shape=[None, 4])
    input_y = tf.placeholder(dtype=tf.float32, shape=[None, 4])

    # 将量化矩阵赋值给网络这一位置的权重
    w = tf.Variable(initial_value=tf.constant(w_src), name='w')

    b = tf.Variable(initial_value=tf.constant(b_src), name='b')
    logits = tf.matmul(input_x, w) + b

    # 将重构的网络结构保存
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.initialize_all_variables())
        saver.save(sess,'./qt_output/qt/best_model_qt')




def test_qt():
    '''
    测试量化矩阵是否被赋值
    :return:
    '''

    #恢复重构网络
    tf.reset_default_graph()
    saver = tf.train.import_meta_graph('./qt_output/qt/best_model_qt.meta')

    with tf.Session() as sess:
        saver.restore(sess,'./qt_output/qt/best_model_qt')
        # 打印此时的权重矩阵,结果为量化矩阵,说明已经赋值给tensor成功
        print(sess.run(tf.get_default_graph().get_tensor_by_name('w:0')))





if __name__ == '__main__':

    source_net()
    quant_process()
    update_net()
    test_qt()

